import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
import pandas as pd
from pathos.multiprocessing import Pool

# Global matplotlib defaults applied to every figure produced by this script.
rcParams.update({
    'figure.figsize': (13.0, 7.0),
    'lines.linewidth': 3,
    'font.size': 18,
})

from jm.library.data_helper import filter_data
from jm.library.likelihood import Likelihood

from jm.library.additional_helpers import plot_over_time, payment_cols, priority_cols, action_cols, \
    relative_payment_cols, fwdiff_payment_cols, relative_fwdiff_payment_cols, get_cols_over_time


def is_ever(df, cols, value, columns=None):
    """Flag, per row and per period, whether ``value`` has appeared so far.

    Column ``i`` of the result is 1 when ``value`` equals any of
    ``df[cols[0]], ..., df[cols[i]]`` in that row and 0 otherwise, i.e. a
    running "ever seen" indicator across ``cols`` taken in the given order.

    Parameters
    ----------
    df : pandas.DataFrame
        Source data; must contain every label in ``cols``.
    cols : sequence of column labels
        Columns to scan, in the order the running indicator should follow
        (presumably chronological; TODO confirm against callers).
    value : scalar
        Value to look for. It is compared with ``==``, so missing values never match.
    columns : sequence, optional
        Labels for the result's columns. Defaults to ``Likelihood.event_dates``,
        which must then have the same length as ``cols``.

    Returns
    -------
    pandas.DataFrame
        int64 0/1 frame with ``df``'s index and one column per entry of ``cols``.
    """
    if columns is None:
        columns = Likelihood.event_dates
    # A single element-wise match followed by a running maximum along each row
    # gives "matched in any of the first i+1 columns" for every i in one pass.
    # This replaces the former per-prefix process pool, which mutated module
    # globals and shipped the whole frame to worker processes.
    # list() keeps a tuple of labels from being read as a MultiIndex key.
    matches = df.loc[:, list(cols)].eq(value).astype('int64')
    ever = matches.cummax(axis=1)
    ever.columns = columns
    return ever


# Simulation estimates for two configurations, plus the de-identified status data.
df_sim = pd.read_csv(
    'estimation/simulation_estimates/longer_G1_more_G1_actual_estimates.csv')
df_sim_control = pd.read_csv(
    'estimation/simulation_estimates/control_config_control_total_due_estimates.csv')
df_status = pd.read_csv('data/jesus_maria_status_deidentified.csv')

# Express payment columns (levels and the "fwdiff" variants) as a share of each
# row's 'total_due'. The division is row-wise (axis=0 aligns on the index).
df_status[relative_payment_cols] = df_status[payment_cols].div(df_status['total_due'], axis=0)
df_status[relative_fwdiff_payment_cols] = df_status[fwdiff_payment_cols].div(
    df_status['total_due'], axis=0)

# Split by treatment assignment flag.
df_treatment = filter_data(df_status, assignment_to_treatment=1)
df_control = filter_data(df_status, assignment_to_treatment=0)

# Descending by score. mergesort is stable, so tied rows keep their prior order.
df_treatment = df_treatment.sort_values(by='effective_score', ascending=False, kind='mergesort')

state_0_selected_cols = [
    'total_due',
    payment_cols[0],
    priority_cols[0],
    action_cols[0],
    'prob_repayment_endo_covariates',
]


# In[54]:


def _sim_over_time(sim_df, col_fmt):
    """Return the per-period columns of ``sim_df`` named by ``col_fmt``, relabelled.

    The column names come from ``get_cols_over_time(col_fmt)``. They are assumed
    to be one per event date, in the same order as ``Likelihood.event_dates``
    (TODO confirm). The selected columns are then renamed positionally to
    ``Likelihood.event_dates``.
    """
    frame = pd.DataFrame(sim_df[get_cols_over_time(col_fmt)])
    frame.columns = Likelihood.event_dates
    return frame


# Simulated outcomes for the main configuration, one frame per outcome.
# (A stray notebook echo, `base_res_everrec1.mean()`, whose result was
# discarded has been dropped.)
base_res_relpayments = _sim_over_time(df_sim, 'relative_payments_by_{}')
base_res_totalpayments = _sim_over_time(df_sim, 'payments_by_{}')
base_res_inG1 = _sim_over_time(df_sim, 'G1_priority_by_{}')
base_res_inmedida = _sim_over_time(df_sim, 'medida_action_by_{}')
base_res_everrec1 = _sim_over_time(df_sim, 'rec1_action_by_{}')
base_res_evervalor = _sim_over_time(df_sim, 'valor_action_by_{}')


# In[45]:


# Figure OB1 (long version): treatment-group data ("actual") against the
# column-wise mean of the simulated per-period outcomes ("sim").
fig, ((ax1, ax2, ax3), (ax4, ax5, ax6)) = plt.subplots(2, 3)

# Cumulative collections, both series in millions of soles.
plot_over_time(df_treatment, ax=ax1, normalize=1e6)
# NOTE: rebinds the module-level frame to millions; re-running this cell in a
# notebook would rescale it a second time.
base_res_totalpayments = base_res_totalpayments/1e6
base_res_totalpayments.mean().plot(ax=ax1)
ax1.set_title('cumulative taxes collected')
ax1.set_ylabel('M soles')
ax1.legend(['actual', 'sim'], loc="upper left", fontsize=13)

plot_over_time(df_treatment, col_fmt='relative_payments_by_{}', func=np.mean, ax=ax2)
base_res_relpayments.mean().plot(ax=ax2)
ax2.set_title('mean relative payment')
ax2.legend(['actual', 'sim'], loc="upper left", fontsize=13)

# Count panels. "actual" is the number of rows whose `cols` have contained
# `value` at or before each period (via is_ever). "sim" is the mean of the
# matching simulated frame. Lines are drawn actual-then-sim to match the legend.
treatment_count_panels = [
    (ax3, priority_cols, 'G1', base_res_inG1, 'number with G1 priority'),
    (ax4, action_cols, 'valor', base_res_evervalor, 'number with notifications'),
    (ax5, action_cols, 'medida', base_res_inmedida, 'number of garnishments'),
    (ax6, action_cols, 'rec1', base_res_everrec1, 'number of writs'),
]
for panel_ax, panel_cols, panel_value, panel_sim, panel_title in treatment_count_panels:
    is_ever(df_treatment, panel_cols, panel_value).sum(axis=0).plot(ax=panel_ax)
    panel_sim.mean().plot(ax=panel_ax)
    panel_ax.set_title(panel_title)
    panel_ax.legend(['actual', 'sim'], loc="upper left", fontsize=13)

fig.tight_layout()

# savefig does not create missing directories, so ensure the output folder exists.
os.makedirs('figs', exist_ok=True)
fig.savefig('figs/figOB1_longversion.pdf')


# In[55]:


# Simulated outcomes for the control configuration, one frame per outcome,
# built in the same order as before. No G1-priority frame is built here.
(
    control_res_relpayments,
    control_res_totalpayments,
    control_res_inmedida,
    control_res_everrec1,
    control_res_evervalor,
) = [
    pd.DataFrame(df_sim_control[get_cols_over_time(col_fmt)])
    for col_fmt in (
        'relative_payments_by_{}',
        'payments_by_{}',
        'medida_action_by_{}',
        'rec1_action_by_{}',
        'valor_action_by_{}',
    )
]

# Relabel every frame's columns positionally with the event dates.
for control_frame in (control_res_relpayments, control_res_totalpayments,
                      control_res_inmedida, control_res_everrec1, control_res_evervalor):
    control_frame.columns = Likelihood.event_dates
del control_frame  # loop temporary only; leave the module namespace unchanged


# In[56]:


# Figure OB2 (long version): control-group data ("actual") against the
# column-wise mean of the simulated control outcomes ("sim"). No G1-priority
# panel is drawn for the control group, so the top-right axes is removed.
fig, ((ax1, ax2, ax3), (ax4, ax5, ax6)) = plt.subplots(2, 3)
fig.delaxes(ax3)

# Cumulative collections, both series in millions of soles.
plot_over_time(df_control, ax=ax1, normalize=1e6)
# NOTE: rebinds the module-level frame to millions; re-running this cell in a
# notebook would rescale it a second time.
control_res_totalpayments = control_res_totalpayments/1e6
control_res_totalpayments.mean().plot(ax=ax1)
ax1.set_title('cumulative taxes collected')
ax1.set_ylabel('M soles')
ax1.legend(['actual', 'sim'], loc="upper left", fontsize=13)

plot_over_time(df_control, col_fmt='relative_payments_by_{}', func=np.mean, ax=ax2)
control_res_relpayments.mean().plot(ax=ax2)
ax2.set_title('mean relative payment')
ax2.legend(['actual', 'sim'], loc="upper left", fontsize=13)

# Action-count panels. "actual" is the number of rows whose action columns have
# contained `value` at or before each period (via is_ever). "sim" is the mean of
# the matching simulated frame. Lines are drawn actual-then-sim to match the legend.
control_count_panels = [
    (ax4, 'valor', control_res_evervalor, 'number with notifications'),
    (ax5, 'medida', control_res_inmedida, 'number of garnishments'),
    (ax6, 'rec1', control_res_everrec1, 'number of writs'),
]
for panel_ax, panel_value, panel_sim, panel_title in control_count_panels:
    is_ever(df_control, action_cols, panel_value).sum(axis=0).plot(ax=panel_ax)
    panel_sim.mean().plot(ax=panel_ax)
    panel_ax.set_title(panel_title)
    panel_ax.legend(['actual', 'sim'], loc="upper left", fontsize=13)

fig.tight_layout()

# savefig does not create missing directories, so ensure the output folder exists.
os.makedirs('figs', exist_ok=True)
fig.savefig('figs/figOB2_longversion.pdf')


# In[ ]:




